iT邦幫忙

2023 iThome 鐵人賽

DAY 13
0
AI & Data

紮實的ML機器學習原理~打造你對資料使用sklearn的靈敏度系列 第 13

DAY 13 「隨機搜索(Random Search)& Halving 網格搜索(Halving Grid Search)」最佳超參數組合來做鳶尾花分類啦~

  • 分享至 

  • xImage
  •  

模型調優最佳幫手演算法~~

/images/emoticon/emoticon08.gif白話來說可以幫助找到模型的最佳超參數組合~~

  • 隨機搜索(Random Search):
    隨機搜索是一種簡單但高效的超參數調優方法,在超參數空間中隨機選擇一組超參數進行訓練和評估。相對於網格搜索,隨機搜索可以更快地找到較好的超參數組合,特別是當超參數的數量很大時

基本原理:
定義超參數空間:確定需要調優的超參數以及其取值範圍。
隨機選擇超參數:在超參數空間中隨機選擇一組超參數。
訓練模型:使用選定的超參數在訓練集上訓練模型。
評估性能:使用驗證集或交叉驗證對模型性能進行評估。
重覆步驟 2 到 4:重覆多次,直到達到預定的訓練次數或時間。

from sklearn.model_selection import RandomizedSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
import numpy as np

# 載入鳶尾花數據集
iris = load_iris()
X, y = iris.data, iris.target

# 定義超參數空間
param_dist = {
    'n_estimators': [50, 100, 200],
    'max_depth': [3, 5, 10],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4]
}

# 初始化隨機搜索對象
random_search = RandomizedSearchCV(
    estimator=RandomForestClassifier(),
    param_distributions=param_dist,
    n_iter=10,  # 隨機搜索次數
    cv=5,  # 交叉驗證的折數
    random_state=42
)

# 執行隨機搜索
random_search.fit(X, y)

# 輸出結果
print(f'隨機搜索最佳參數:{random_search.best_params_}')
print(f'隨機搜索最佳分數:{random_search.best_score_}')
  • Halving 網格搜索(Halving Grid Search):
    Halving 網格搜索是一種結合了網格搜索和隨機搜索的方法,通過在每一步中隨機選取一部分參數進行訓練和評估,從而減少了搜索空間,提高了搜索效率

基本原理:
定義超參數空間:確定需要調優的超參數以及其取值範圍。
隨機選擇超參數子集:在超參數空間中隨機選擇一部分超參數。
訓練模型:使用選定的超參數子集在訓練集上訓練模型。
評估性能:使用驗證集或交叉驗證對模型性能進行評估。
淘汰低性能模型:根據性能,淘汰表現較差的模型。
縮小超參數空間:將表現較好的模型對應的超參數子集作為下一輪搜索的候選。
重覆步驟 2 到 6:重覆多次,直到達到預定的訓練次數或時間。

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV, HalvingRandomSearchCV
from sklearn.datasets import load_iris

# 載入鳶尾花數據集
iris = load_iris()
X, y = iris.data, iris.target

# 定義超參數空間
param_dist = {
    'n_estimators': [50, 100, 200],
    'max_depth': [3, 5, 10],
    'min_samples_split': [2, 5, 10]
}

# 隨機搜索
random_search = RandomizedSearchCV(
    RandomForestClassifier(),
    param_distributions=param_dist,
    n_iter=10,  # 隨機搜索次數
    cv=5,  # 交叉驗證的折數
    random_state=42
)
random_search.fit(X, y)
print(f'隨機搜索最佳參數:{random_search.best_params_}')

# Halving 網格搜索
halving_search = HalvingRandomSearchCV(
    RandomForestClassifier(),
    param_distributions=param_dist,
    factor=3,  # 每一輪淘汰的比例
    resource='n_samples',  # 指定資源(樣本數)來決定淘汰
    max_resources=100,  # 最大使用資源(樣本數)
    cv=5,  # 交叉驗證的折數
    random_state=42
)
halving_search.fit(X, y)
print(f'Halving 網格搜索最佳參數:{halving_search.best_params_}')

上一篇
DAY 12 「集成學習(Ensemble Learning)Boosting 和 AdaBoost」AI專案落地必有的觀念啦~
下一篇
DAY 14 「梯度提升樹(Gradient Boosting Decision Trees / GBDT)」集成來做鳶尾花分類啦~
系列文
紮實的ML機器學習原理~打造你對資料使用sklearn的靈敏度30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言